Conversation

@gabeweisz (Contributor)

Description

When using THD-format packed data with TransformerEngine, the user must specify the maximum number of segments that can be packed into a sequence at JAX JIT-trace time. If Grain packs more segments into a sequence than that limit allows, the result can be crashes or data corruption.

We previously updated Grain to allow limiting the number of segments packed into a sequence; this PR reads the appropriate value from the MaxText configuration and passes it to Grain.
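To illustrate the failure mode, here is a minimal first-fit packing sketch with a per-bin segment cap. This is illustrative only: `pack_segments` and its parameters are our own names, not Grain's or MaxText's API (the review below refers to Grain's actual `max_sequences_per_bin` argument). Without the cap, a bin can end up holding more segments than the segment-ID buffers sized at JIT-trace time can represent.

```python
def pack_segments(lengths, bin_capacity, max_segments_per_bin):
    """First-fit packing of variable-length segments into fixed-capacity bins.

    The max_segments_per_bin cap mirrors the limit this PR plumbs through:
    a bin may have spare token capacity yet still be "full" because adding
    another segment would exceed the segment count fixed at JIT-trace time.
    """
    bins = []  # each entry: [used_length, list_of_segment_lengths]
    for length in lengths:
        placed = False
        for b in bins:
            if b[0] + length <= bin_capacity and len(b[1]) < max_segments_per_bin:
                b[0] += length
                b[1].append(length)
                placed = True
                break
        if not placed:
            bins.append([length, [length]])
    return [b[1] for b in bins]

# With a cap of 2 segments per bin, the third segment opens a new bin
# even though the first bin still has token capacity left.
print(pack_segments([3, 3, 2, 2], bin_capacity=8, max_segments_per_bin=2))
# → [[3, 3], [2, 2]]
print(pack_segments([3, 3, 2, 2], bin_capacity=8, max_segments_per_bin=10))
# → [[3, 3, 2], [2]]
```

If the model were traced expecting at most 2 segments per sequence, the uncapped packing above would produce a 3-segment sequence and trigger exactly the kind of corruption this PR prevents.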

Tests

We have had this fix in place in our AMD fork of MaxText for some time, but needed to get the Grain fix upstreamed first before creating this PR.
We have tested this fix extensively internally and have customers using it in production.

MaxText does not currently have any tests that use packed batches, but I can create some if needed.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@yeandy (Contributor) left a comment


We may need to add max_sequences_per_bin=config.max_segments_per_seq in make_hf_eval_iterator too

@gabeweisz (Contributor, Author)

> We may need to add max_sequences_per_bin=config.max_segments_per_seq in make_hf_eval_iterator too

Done, thanks for the tip

@gabeweisz gabeweisz closed this Dec 4, 2025
@gabeweisz gabeweisz reopened this Dec 4, 2025
@gabeweisz gabeweisz requested a review from aireenmei December 18, 2025 12:18
@aireenmei (Collaborator) left a comment


Thanks!

@gabeweisz gabeweisz force-pushed the gw_plumb_max_sequences_per_seq_to_grain branch from 1c5b56e to d107901 Compare December 18, 2025 23:47
@gabeweisz (Contributor, Author)

@aireenmei I apologize, but I needed to make some updates for the code-quality checks (they didn't run before your approval). Can you please take another look? The edits were minor.

@gabeweisz gabeweisz requested a review from aireenmei December 18, 2025 23:51
@aireenmei (Collaborator)

Could you take a look to see if this error is related: https://github.com/AI-Hypercomputer/maxtext/actions/runs/20354898076/job/58493556179?pr=2774

@aireenmei (Collaborator)

I've asked @SurbhiJainUSC to also take a look. @gabeweisz you will need to squash your commits into 1 before merging

@gabeweisz gabeweisz force-pushed the gw_plumb_max_sequences_per_seq_to_grain branch from d107901 to f7971f2 Compare December 19, 2025 16:47
@gabeweisz (Contributor, Author)

@aireenmei thanks for the feedback; I have squashed the commits.
@SurbhiJainUSC please take a look; I need a review from a reviewer with write access.

@gabeweisz (Contributor, Author)

Looks like I still need an approval from a maintainer. How can I get one?

@copybara-service copybara-service bot merged commit f0526a4 into AI-Hypercomputer:main Dec 30, 2025
25 checks passed
